import torch
import lietorch
import droid_backends
import droid_backends_bamf
from lietorch import SE3, Sim3
from torch.multiprocessing import Value

from modules.droid_net import cvx_upsample
import geom.projective_ops as pops
from geom.ba import JDSA
from geom.ba import get_preint_factors, get_bias_factors, get_bias_prior_factors, velo_retr, bias_retr
from pgo_buffer import global_relative_posesim3_constraints
import numpy as np
from imu import IMUIntegrator
import time
class DepthVideo:
    """Shared keyframe buffer for visual(-inertial) SLAM.

    Stores per-keyframe state in preallocated tensors with ``buffer`` slots:
    timestamps, images, camera poses (7-dof ``[t, q]`` and 8-dof Sim3),
    1/8-resolution and full-resolution disparities, depth priors, normals,
    intrinsics and network feature maps. When IMU data is supplied
    (``args.imus``), it also keeps per-keyframe velocities, biases and the
    preintegrated measurements between consecutive keyframes.

    ``__setitem__``, ``__getitem__``, ``append``, ``normalize``, ``cuda_ba`` and
    ``inertial_ba`` hold ``self.get_lock()`` (the lock of the shared keyframe
    counter) while they run.

    NOTE(review): ``self.pgobuf``, ``self.gs`` and ``self.mono_depth_alpha`` are
    read by methods here but never assigned in this class, and ``self.Rwg`` is
    initialised to ``None``. They are presumably set by the owner of this
    object; verify before relying on them.
    """

    def __init__(self, config, args, image_size, buffer):
        """Allocate all keyframe buffers.

        Args:
            config: nested config dict. Reads ``['Tracking']['disable_mono']``
                and ``['Tracking']['frontend']['imu_late_init_from']`` here, and
                ``['Dataset']['scale_multiplier']`` in ``normalize``.
            args: namespace providing ``IMU_poseinit_after``, ``imus``,
                ``init_g``, ``init_bg``, ``init_ba`` and ``Tcb`` (plus
                ``output`` for the optional PGBA debug dumps).
            image_size: ``(height, width)`` of the full-resolution images.
            buffer: number of keyframe slots to allocate.
        """
        self.IMU_initialized = False
        self.args = args
        # current keyframe count (multiprocessing shared value; its lock guards the buffer)
        self.counter = Value('i', 0)
        self.ht = ht = image_size[0]
        self.wd = wd = image_size[1]
        self.is_initialized = False
        self.config = config
        self.disable_mono = config['Tracking']['disable_mono']
        self.imu_late_init_from = config['Tracking']['frontend']['imu_late_init_from']
        print('self.disable_mono', self.disable_mono)
        self.IMU_poseinit_after = args.IMU_poseinit_after

        ### state attributes ###
        self.tstamp = torch.zeros(buffer, device="cuda", dtype=torch.float).share_memory_()
        self.images = torch.zeros(buffer, 3, ht, wd, device="cpu", dtype=torch.uint8)
        self.dirty = torch.zeros(buffer, device="cuda", dtype=torch.bool).share_memory_()
        # camera poses as [tx, ty, tz, qx, qy, qz, qw] (stored as T_cw, see init_next_pose)
        self.poses = torch.zeros(buffer, 7, device="cuda", dtype=torch.float).share_memory_()
        self.poses_sim3 = torch.zeros(buffer, 8, device="cuda", dtype=torch.float).share_memory_()
        self.disps = torch.ones(buffer, ht//8, wd//8, device="cuda", dtype=torch.float).share_memory_()
        self.disps_up = torch.zeros(buffer, ht, wd, device="cpu", dtype=torch.float).share_memory_() #temporary to cuda
        self.disps_prior = torch.zeros(buffer, ht//8, wd//8, device="cuda", dtype=torch.float).share_memory_()
        self.disps_prior_up = torch.zeros(buffer, ht, wd, device="cpu", dtype=torch.float).share_memory_()
        self.intrinsics = torch.zeros(buffer, 4, device="cuda", dtype=torch.float).share_memory_()
        self.normals = torch.zeros(buffer, 3, ht, wd, device="cpu", dtype=torch.float)

        ### feature attributes ###
        self.fmaps = torch.zeros(buffer, 1, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()
        self.nets = torch.zeros(buffer, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()
        self.inps = torch.zeros(buffer, 128, ht//8, wd//8, dtype=torch.half, device="cuda").share_memory_()

        # initialize poses to identity transformation
        self.poses[:] = torch.as_tensor([0, 0, 0, 0, 0, 0, 1], dtype=torch.float, device="cuda")
        self.poses_sim3[:] = torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1], dtype=torch.float, device="cuda")

        # depth prior scale
        self.dscales = torch.ones(buffer, 2, 2, device='cuda', dtype=torch.float).share_memory_()

        ### IMU states ###
        self.imus = args.imus
        self.Rwg = None
        self.init_g = args.init_g
        # (i-1, i) -> IMUIntegrator between consecutive keyframes
        self.preints = {}
        # keyframe index -> timestamp used for IMU preintegration
        self.kf_stamps = {}
        self.velos_w = torch.zeros(buffer, 3, device='cuda', dtype=torch.float)
        # per-keyframe bias vector [init_bg, init_ba]; [:3] and [:, 3:] are passed
        # separately to IMUIntegrator (presumably gyro then accel bias)
        self.biass_w = torch.tensor(np.concatenate([args.init_bg, args.init_ba]), dtype=torch.float, device='cuda').repeat(buffer, 1)
        self.Tcb = args.Tcb

    def get_lock(self):
        """Return the lock of the shared keyframe counter."""
        return self.counter.get_lock()

    # TODO: might also need to scale other metrics
    def rescale(self, s, t1):
        """Rescale the metric quantities of the first ``t1`` keyframes by ``s``.

        Translations and velocities are multiplied by ``s``; disparities
        (both resolutions) and depth-prior scales are divided by ``s``.
        If a pose-graph buffer (``self.pgobuf``) exists, its stored relative
        poses are recomputed from the rescaled poses, and the translational
        covariances are multiplied by the squared ratio of new to old
        relative-translation norms. If a Gaussian map (``self.gs``) exists,
        it is rescaled as well.
        """
        print('rescaling to', s)
        self.poses[:t1, :3] *= s
        self.velos_w[:t1] *= s
        self.disps[:t1] /= s
        self.disps_up[:t1] /= s
        self.dscales[:t1] /= s

        # pose graph (relative-pose constraints are only added after IMU init)
        if hasattr(self, 'pgobuf') and self.pgobuf is not None:
            rel_N = self.pgobuf.rel_N.value
            poses = SE3(self.poses[:t1][None])
            rel_poses = poses[:, self.pgobuf.rel_jj[:rel_N]] * poses[:, self.pgobuf.rel_ii[:rel_N]].inv()
            prev_norm = torch.linalg.norm(self.pgobuf.rel_poses[:rel_N][:, :3], dim=1)
            cur_norm = torch.linalg.norm(rel_poses.data[0][:, :3].cpu(), dim=1)
            # NOTE(review): a zero-length previous relative translation gives
            # nan/inf here; cuda_pgba later zeroes non-finite information entries.
            rel_scale = cur_norm / prev_norm
            self.pgobuf.rel_poses[:rel_N] = rel_poses.data[0].cpu()
            self.pgobuf.rel_covs[:rel_N, :3] *= (rel_scale * rel_scale).unsqueeze(1)

        # gaussian map (rebuilt from the beginning anyway)
        if hasattr(self, 'gs') and self.gs is not None:
            self.gs.rescale(s)

    def rm_and_reintegrate(self, index):
        """Copy the IMU state of keyframe ``index + 1`` into slot ``index``.

        The timestamp, velocity and bias of ``index + 1`` overwrite those of
        ``index``. With IMU data, the preintegration ``(index, index + 1)`` is
        dropped and ``(index - 1, index)`` is recomputed.
        """
        self.kf_stamps[index] = self.kf_stamps[index+1]
        self.velos_w[index] = self.velos_w[index+1]
        self.biass_w[index] = self.biass_w[index+1]
        if self.imus is not None:
            del self.preints[(index, index+1)]
            self.__preintegrate(index)

    def reintegrate_all(self):
        """Recompute every preintegration ``(i - 1, i)`` for ``1 <= i < counter``."""
        self.preints = {}
        for i in range(1, self.counter.value):
            self.__preintegrate(i)

    def __preintegrate(self, index):
        """Preintegrate IMU samples between keyframes ``index - 1`` and ``index``.

        Uses the samples whose timestamp (``m[0]``) lies strictly between the
        two keyframe stamps, and the bias of keyframe ``index - 1``. Stores the
        integrator in ``self.preints[(index - 1, index)]``. No-op for
        ``index < 1``.
        """
        if index < 1:
            return
        prev_stamp = self.kf_stamps[index-1]
        curr_stamp = self.kf_stamps[index]

        measurements = [m for m in self.imus if prev_stamp < m[0] < curr_stamp]
        inter = IMUIntegrator(prev_stamp, curr_stamp, self.biass_w[index-1, :3].cpu().numpy(), self.biass_w[index-1, 3:].cpu().numpy(), self.init_g, self.args, self.config)
        inter.integrate(measurements)
        self.preints[(index-1, index)] = inter

    def init_next_pose(self, index, use_uncer=False):
        """Predict pose and velocity of keyframe ``index`` with IMU preintegration.

        Propagates the IMU state of keyframe ``index - 1`` through the
        preintegrated deltas ``(index - 1, index)``, including gravity
        ``self.Rwg @ preint.g``. Writes the resulting camera pose (``T_cw`` as
        ``[t, q]``) to ``self.poses[index]`` and the velocity to
        ``self.velos_w[index]``.

        Args:
            index: keyframe to initialise (must be >= 1 and already preintegrated).
            use_uncer: if True, the position is not propagated (kept at the
                previous IMU position) when the trace of the preintegrated
                position covariance ``preint.cov[6:9, 6:9]`` is >= 1e-4.
                Rotation and velocity are always propagated.
        """
        from scipy.spatial.transform import Rotation as R

        def pose_to_SE3(pose):
            """Convert 7D [t, q] to 4x4 SE3 matrix."""
            t = pose[:3].cpu().numpy()
            q = pose[3:].cpu().numpy()
            R_mat = R.from_quat(q).as_matrix()
            T = torch.eye(4)
            T[:3, :3] = torch.tensor(R_mat)
            T[:3, 3] = torch.tensor(t)
            return T

        def SE3_to_pose(T):
            """Convert 4x4 SE3 to 7D pose [t, q]."""
            R_mat = T[:3, :3].cpu().numpy()
            t = T[:3, 3].cpu().numpy()
            q = R.from_matrix(R_mat).as_quat()
            return torch.cat([torch.tensor(t, dtype=T.dtype), torch.tensor(q, dtype=T.dtype)])

        # Previous camera pose and velocity
        pose_prev = self.poses[index - 1]
        vel_prev = self.velos_w[index - 1]
        T_cam_prev = pose_to_SE3(pose_prev).cuda()

        # IMU preintegration delta
        preint = self.preints[(index - 1, index)]
        dP = torch.tensor(preint.get_updated_dP(), dtype=pose_prev.dtype, device=pose_prev.device)
        dV = torch.tensor(preint.get_updated_dV(), dtype=pose_prev.dtype, device=pose_prev.device)
        dR_log = torch.tensor(preint.get_updated_dR().log(), dtype=pose_prev.dtype, device=pose_prev.device)

        # Position usually carries the largest uncertainty. Without use_uncer it is
        # always propagated (previously this path raised NameError on `trace_pos`).
        propagate_position = True
        if use_uncer:
            cov = preint.cov[:9, :9]
            cov_pos = cov[6:9, 6:9]
            trace_pos = np.trace(cov_pos)
            propagate_position = trace_pos < 1e-4

        dT = preint.dT
        g = torch.tensor(self.Rwg @ preint.g, dtype=pose_prev.dtype, device=pose_prev.device)
        # Rotation increment
        dR_mat = R.from_rotvec(dR_log.cpu().numpy()).as_matrix()
        dR_mat = torch.tensor(dR_mat, dtype=pose_prev.dtype, device=pose_prev.device)

        # Tcb: body (IMU) to camera, invert it to get Tbc
        T_cb = self.Tcb.matrix().squeeze(0).squeeze(0)           # (4, 4)
        T_bc = torch.linalg.inv(T_cb)

        # T_wc @ T_cb = T_wb: IMU (body) pose expressed in the world frame
        T_imu_prev = torch.linalg.inv(T_cam_prev) @ T_cb
        R_imu_prev = T_imu_prev[:3, :3]
        p_imu_prev = T_imu_prev[:3, 3]

        # New imu rotation
        R_imu_next = R_imu_prev @ dR_mat

        # New imu velocity
        vel_imu_prev = vel_prev
        vel_imu_next = vel_imu_prev + g * dT + R_imu_prev @ dV

        # New imu position
        if propagate_position:
            p_imu_next = p_imu_prev + vel_imu_prev * dT + 0.5 * g * dT * dT + R_imu_prev @ dP
        else:
            p_imu_next = p_imu_prev

        # Assemble T_wb for the next frame
        T_imu_next = torch.eye(4, dtype=pose_prev.dtype, device=pose_prev.device)
        T_imu_next[:3, :3] = R_imu_next
        T_imu_next[:3, 3] = p_imu_next

        # Now transform back to camera: T_wb @ T_bc = T_wc
        T_cam_next = T_imu_next @ T_bc
        T_cam_next = torch.linalg.inv(T_cam_next)  # we store T_cw

        # Convert back to pose format
        pose_next = SE3_to_pose(T_cam_next).cuda()
        self.poses[index] = pose_next
        self.velos_w[index] = vel_imu_next

    @torch.cuda.amp.autocast(enabled=False)
    def __item_setter(self, index, item):
        """Write keyframe data into slot(s) ``index``.

        Called by ``__setitem__`` and ``append`` with the counter lock held.

        ``item`` is laid out as
        ``(tstamp, image, pose, disp, depth, normal, intrinsics[, fmap[, net[, inp[, kf_stamp]]]])``.
        Rules for each entry:

        * ``None`` pose/disp/depth/normal entries are skipped.
        * A ``None`` intrinsics entry copies the intrinsics of keyframe 0.
        * ``depth`` is stored inverted (as disparity).
        * A non-``None`` ``kf_stamp`` triggers IMU preintegration (when IMU data
          exists). From ``IMU_poseinit_after`` on, it also writes an
          IMU-predicted pose, which is overwritten if ``pose`` is given.
        """
        # Grow the keyframe counter so it covers the highest written index.
        if isinstance(index, int) and index >= self.counter.value:
            self.counter.value = index + 1

        elif isinstance(index, torch.Tensor) and index.max().item() >= self.counter.value:
            # `>=` (not `>`) so writing exactly at slot `counter` also grows it,
            # matching the int branch above.
            self.counter.value = index.max().item() + 1

        self.tstamp[index] = item[0]

        # imu related update
        if len(item) > 10 and item[10] is not None:
            self.kf_stamps[index] = item[10]
            if self.imus is not None:
                self.__preintegrate(index)
                if index >= self.IMU_poseinit_after:
                    self.init_next_pose(index, use_uncer=True)

        self.images[index] = item[1]

        if item[2] is not None:
            self.poses[index] = item[2]

        if item[3] is not None:
            self.disps[index] = item[3]

        if item[4] is not None:
            # NOTE(review): unlike disps_prior below, zero depths become inf here;
            # confirm downstream consumers of disps_prior_up expect that.
            self.disps_prior_up[index] = 1.0/item[4]
            # sample the full-resolution depth at the 1/8 grid (offset 3, stride 8)
            depth = item[4][3::8, 3::8]
            self.disps_prior[index] = torch.where(depth > 0, 1.0/depth, 0).cuda()

        if item[5] is not None:
            self.normals[index] = item[5]

        if item[6] is not None:
            self.intrinsics[index] = item[6]
        else:
            self.intrinsics[index] = self.intrinsics[0].clone()

        if len(item) > 7:
            self.fmaps[index] = item[7]

        if len(item) > 8:
            self.nets[index] = item[8]

        if len(item) > 9:
            self.inps[index] = item[9]

    def __setitem__(self, index, item):
        """Write ``item`` into slot(s) ``index`` under the counter lock."""
        with self.get_lock():
            self.__item_setter(index, item)

    def __getitem__(self, index):
        """Index the depth video.

        Returns ``(pose, disp, intrinsics, fmap, net, inp)`` for ``index``.
        A negative int index counts back from the current keyframe count.
        """
        with self.get_lock():
            # support negative indexing
            if isinstance(index, int) and index < 0:
                index = self.counter.value + index

            item = (
                self.poses[index],
                self.disps[index],
                self.intrinsics[index],
                self.fmaps[index],
                self.nets[index],
                self.inps[index])

        return item

    def append(self, *item):
        """Write ``item`` into the next free slot (``counter``) under the lock."""
        with self.get_lock():
            self.__item_setter(self.counter.value, item)

    @staticmethod
    def format_indicies(ii, jj):
        """Convert index sequences to flat ``long`` CUDA tensors."""

        if not isinstance(ii, torch.Tensor):
            ii = torch.as_tensor(ii)

        if not isinstance(jj, torch.Tensor):
            jj = torch.as_tensor(jj)

        ii = ii.to(device="cuda", dtype=torch.long).reshape(-1)
        jj = jj.to(device="cuda", dtype=torch.long).reshape(-1)

        return ii, jj

    def upsample(self, ix, mask):
        """Convex-upsample the disparities of keyframe(s) ``ix`` into ``disps_up`` (CPU)."""

        disps_up = cvx_upsample(self.disps[ix].unsqueeze(-1), mask)
        self.disps_up[ix] = disps_up.squeeze().cpu()

    def normalize(self, enforce_scale=None):
        """Normalize depth and poses of all current keyframes by a scalar ``s``.

        ``s`` is ``enforce_scale`` if given, otherwise the mean disparity times
        ``config['Dataset']['scale_multiplier']``. Translations are multiplied
        by ``s``; disparities and depth-prior scales are divided by ``s``.
        The touched keyframes are marked dirty.
        """

        with self.get_lock():
            if enforce_scale is not None:
                s = enforce_scale
            else:
                s = self.disps[:self.counter.value].mean().item() * self.config['Dataset']['scale_multiplier']

            print(f"Normalize pose and depth by {s:.3f}")
            self.poses[:self.counter.value, :3] *= s
            self.disps[:self.counter.value] /= s
            self.disps_up[:self.counter.value] /= s
            self.dscales[:self.counter.value] /= s
            self.dirty[:self.counter.value] = True

    def reproject(self, ii, jj, sim3=False):
        """Project points from keyframes ``ii`` into keyframes ``jj``.

        Uses the SE3 poses, or the Sim3 poses when ``sim3`` is True.
        Returns ``(coords, valid_mask)`` from ``pops.projective_transform``.
        """
        ii, jj = DepthVideo.format_indicies(ii, jj)
        Gs = Sim3(self.poses_sim3[None]) if sim3 else SE3(self.poses[None])

        coords, valid_mask = \
            pops.projective_transform(Gs, self.disps[None], self.intrinsics[None], ii, jj)

        return coords, valid_mask

    def distance(self, ii=None, jj=None, beta=0.3, bidirectional=True):
        """Frame distance metric (``droid_backends.frame_distance``).

        If ``ii`` is None, returns the full ``N x N`` matrix over all current
        keyframes. Otherwise returns one distance per ``(ii, jj)`` pair. With
        ``bidirectional``, the ii->jj and jj->ii distances are averaged.
        """

        return_matrix = False
        if ii is None:
            return_matrix = True
            N = self.counter.value
            ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N), indexing='ij')

        ii, jj = DepthVideo.format_indicies(ii, jj)

        if bidirectional:

            poses = self.poses[:self.counter.value].clone()

            d1 = droid_backends.frame_distance(
                poses, self.disps, self.intrinsics[0], ii, jj, beta)

            d2 = droid_backends.frame_distance(
                poses, self.disps, self.intrinsics[0], jj, ii, beta)

            d = .5 * (d1 + d2)

        else:
            d = droid_backends.frame_distance(
                self.poses, self.disps, self.intrinsics[0], ii, jj, beta)

        if return_matrix:
            return d.reshape(N, N)

        return d

    def distance_covis(self, ii=None):
        """Covisibility-based frame distance (``droid_backends.covis_distance``).

        ``ii`` (keyframe indices) is required despite its ``None`` default.
        The result is divided by the median disparity of those keyframes.
        """
        ii = torch.as_tensor(ii)
        ii = ii.to(device="cuda", dtype=torch.long).reshape(-1)
        poses = self.poses[:self.counter.value].clone()
        d = droid_backends.covis_distance(poses, self.disps, self.intrinsics[0], ii)
        d = d * (1. / self.disps[ii].median())
        return d

    def cuda_ba(self, target, weight, eta, ii, jj, t0=1, t1=None, itrs=2, lm=1e-4, ep=0.1, motion_only=False, use_mono=False):
        """Dense bundle adjustment over the keyframe window ``[t0, t1)``.

        Updates ``self.poses`` and ``self.disps`` in place through
        ``droid_backends.ba``. With ``use_mono`` (and mono not disabled),
        ``disps``/``dscales`` are further refined with ``JDSA`` against
        ``disps_prior`` (this reads ``self.mono_depth_alpha``, which is set
        outside this class). Disparities are clamped to >= 0.001 afterwards.
        ``t1`` defaults to ``max(ii, jj) + 1``.
        """
        with self.get_lock():

            # [t0, t1] window of bundle adjustment optimization
            if t1 is None:
                t1 = max(ii.max().item(), jj.max().item()) + 1

            ht, wd = self.disps.shape[1:]
            target = target.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()
            weight = weight.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()

            dx, dz, dzcov = droid_backends.ba(self.poses, self.disps, self.intrinsics[0], target, weight, eta, ii, jj, t0, t1, itrs, lm, ep, motion_only, False)

            if (not self.disable_mono) and use_mono:
                poses = lietorch.SE3(self.poses[:t1][None])
                disps = self.disps[:t1][None]
                dscales = self.dscales[:t1]
                disps, dscales, _ = JDSA(target, weight, eta, poses, disps, self.intrinsics[None], self.disps_prior, dscales, ii, jj, self.mono_depth_alpha)
                self.disps[:t1] = disps[0]
                self.dscales[:t1] = dscales

            self.disps.clamp_(min=0.001)

    def inertial_ba(self, target, weight, eta, ii, jj, t0=1, t1=None, itrs=2, lm=1e-4, ep=0.1, use_mono=False):
        """Inertial dense bundle adjustment (DBA) over the window ``[t0, t1)``.

        Combines the reprojection factors on edges ``ii -> jj`` with IMU
        preintegration, bias-random-walk and bias-prior factors between
        consecutive keyframes ``(t0-1 .. t1-2) -> (t0 .. t1-1)``. Optimises
        poses (in the body frame), velocities and biases with
        ``droid_backends_bamf.inertial_ba``. Results are written back to
        ``self.poses`` (camera frame), ``self.velos_w`` and ``self.biass_w``.
        Optionally refines disparities with ``JDSA`` as in ``cuda_ba``.
        Disparities are clamped to ``[0.001, 10]`` afterwards.
        """
        verbose = False
        preint_scale = 1e-5
        with self.get_lock():
            # [t0, t1] window of bundle adjustment optimization
            if t1 is None:
                t1 = max(ii.max().item(), jj.max().item()) + 1

            ht, wd = self.disps.shape[1:]
            target = target.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()
            weight = weight.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()

            # consecutive-keyframe edges for the IMU factors
            iii = torch.arange(t0-1, t1-1, device='cuda')
            jjj = iii + 1

            if verbose:
                print("="*100)
                print("Cuda inertial BA from frame '{}' to '{}' size {}".format(t0, t1, ii.shape[0]))

            # in imu/body frame
            poses_bw = self.Tcb.inv() * SE3(self.poses[:t1][None])
            velos_w = self.velos_w[:t1].unsqueeze(0)
            biass_w = self.biass_w[:t1].unsqueeze(0)
            for _ in range(itrs):
                # i->j,   j in camera frame, i in body frame
                Gibj = self.Tcb * poses_bw[:, jj] * poses_bw[:, ii].inv()
                Gij = Gibj * self.Tcb.inv()

                Hintii, Hintij, Hintji, Hintjj, vinti, vintj, chi2 = get_preint_factors(poses_bw, velos_w, biass_w, self.preints, self.Rwg, iii, jjj, preint_scale=preint_scale)
                if verbose:
                    print("- - Chi2 error preint: {:.5f}".format(torch.sum(chi2).item()))

                Hbii, Hbij, Hbji, Hbjj, vbi, vbj = get_bias_factors(biass_w, self.preints, iii, jjj, preint_scale=preint_scale)
                Hbpii, vbpi = get_bias_prior_factors(biass_w, iii, preint_scale=preint_scale)
                Hint = torch.cat([Hintii, Hintij, Hintji, Hintjj, Hbii, Hbij, Hbji, Hbjj, Hbpii, Hbpii])    # 10xNx15x15
                vint = torch.cat([vinti, vintj, vbi, vbj, vbpi])                                            # 5xNx15

                # the backend takes an 8-wide intrinsics row; only the first 4 are filled
                fake_intrinsics = torch.zeros((1, 8), device=self.intrinsics.device, dtype=self.intrinsics.dtype)
                fake_intrinsics[:, :4] = self.intrinsics[0, :4]
                dx = droid_backends_bamf.inertial_ba(poses_bw.data[0], self.disps, fake_intrinsics, Gij.data[0], Gibj.data[0],
                                                self.Tcb.data[0, 0], Hint, vint, target, weight, eta, ii, jj, t0, t1, 1, lm, ep)

                # dx columns 6:9 -> velocity update, 9: -> bias update
                velos_w = velo_retr(velos_w, dx[None, :, 6:9], torch.arange(t1-t0) + t0)
                biass_w = bias_retr(biass_w, dx[None, :, 9:], torch.arange(t1-t0) + t0)

            self.poses[:t1] = (self.Tcb * poses_bw).data[0]
            self.velos_w[:t1] = velos_w[0]
            self.biass_w[:t1] = biass_w[0]

            if (not self.disable_mono) and use_mono:
                poses = lietorch.SE3(self.poses[:t1][None])
                disps = self.disps[:t1][None]
                dscales = self.dscales[:t1]
                disps, dscales, _ = JDSA(target, weight, eta, poses, disps, self.intrinsics[None], self.disps_prior, dscales, ii, jj, self.mono_depth_alpha)
                self.disps[:t1] = disps[0]
                self.dscales[:t1] = dscales

            self.disps.clamp_(min=0.001, max=10)

    def cuda_pgba(self, target, weight, eta, ii, jj, t0=1, t1=None, itrs=2, lm=1e-4, ep=0.1, verbose=False, se3=False):
        """Pose-graph + bundle adjustment on the Sim3 poses over ``[t0, t1)``.

        Builds the reprojection normal equations in Python and the relative-pose
        constraints from ``self.pgobuf``. It then solves them with
        ``droid_backends.pgba`` for ``itrs`` iterations and writes the result to
        ``self.poses_sim3``. With ``se3=True``, the scale degree of freedom
        (last pose dimension) is zeroed out of every block. Disparities are
        clamped to ``[0.001, 10]`` afterwards.

        ``t1`` defaults to one past the highest frame referenced by any
        reprojection edge (``ii``, ``jj``) or relative-pose edge; previously a
        ``None`` ``t1`` crashed.
        """
        from geom.ba import pose_retr

        # rel pose constraints
        rel_N = self.pgobuf.rel_N.value
        iip, jjp = self.pgobuf.rel_ii[:rel_N].cuda(), self.pgobuf.rel_jj[:rel_N].cuda()
        rel_poses = self.pgobuf.rel_poses[:rel_N].cuda()[None]
        infos = 1 / self.pgobuf.rel_covs[:rel_N].cuda()
        # append the smallest information value as the entry for the extra (scale) dimension
        infos = torch.cat((infos, infos.min(dim=1, keepdim=True)[0]), dim=1)
        # diagonal information matrices
        infos = infos.unsqueeze(2).expand(*infos.size(), infos.shape[-1]) * torch.eye(infos.shape[-1], device='cuda')[None]
        infos[torch.isnan(infos) | torch.isinf(infos)] = 0.

        if t1 is None:
            t1 = max(ii.max().item(), jj.max().item()) + 1
            if rel_N > 0:
                t1 = max(t1, iip.max().item() + 1, jjp.max().item() + 1)

        poses = Sim3(self.poses_sim3[:t1][None])

        debug = False
        for itr in range(itrs):
            Hsp, vsp, pchi2, pchi2_scaled, pr = global_relative_posesim3_constraints(iip, jjp, poses, rel_poses, infos, pw=1e-3)
            if debug:
                pgba_info_dict = {
                    'ii': ii,
                    'jj': jj,
                    'iip': iip,
                    'jjp': jjp,
                    'infos': infos,
                    'rel_poses': rel_poses,
                    'rel_N': rel_N,
                    'pr': pr,
                }
                import os
                os.makedirs(os.path.join(self.args.output, 'pgba'), exist_ok=True)
                save_path = os.path.join(self.args.output, 'pgba', f'pgba_info_dict_{t0:03d}_{t1:03d}_iter{itr:01d}.pth')
                if not os.path.exists(save_path):
                    torch.save(pgba_info_dict, save_path)
                    print(f"Saved pgba info dict to {save_path}")

            disps = self.disps[:t1][None]

            if verbose:
                coords, valid = pops.projective_transform(poses, disps, self.intrinsics[None], ii, jj)
                r = (target - coords).view(1, ii.shape[0], -1, 1)
                rw = .001 * (valid * weight).view(1, ii.shape[0], -1, 1)
                rchi2 = torch.sum((rw * r).transpose(2, 3) @ r)
                print("- Chi2 error reproj: {:.5f} relpose: {:.5f} {:.5f}".format(rchi2.item(), pchi2.item(), pchi2_scaled.item()))

            B, P, ht, wd = disps.shape
            N = ii.shape[0]
            D = poses.manifold_dim

            ### 1: compute jacobians and residuals ###
            coords, valid, (Ji, Jj, Jz) = pops.projective_transform(
                poses, disps, self.intrinsics[None], ii, jj, jacobian=True)

            r = (target - coords).view(B, N, -1, 1)
            w = .001 * (valid * weight).view(B, N, -1, 1)

            ### 2: construct linear system ###
            Ji = Ji.reshape(B, N, -1, D)
            Jj = Jj.reshape(B, N, -1, D)
            wJiT = (w * Ji).transpose(2, 3)
            wJjT = (w * Jj).transpose(2, 3)

            Jz = Jz.reshape(B, N, ht*wd, -1)

            Hii = torch.matmul(wJiT, Ji)
            Hij = torch.matmul(wJiT, Jj)
            Hji = torch.matmul(wJjT, Ji)
            Hjj = torch.matmul(wJjT, Jj)
            Hs = torch.cat((Hii, Hij, Hji, Hjj))

            vi = torch.matmul(wJiT, r).squeeze(-1)
            vj = torch.matmul(wJjT, r).squeeze(-1)
            vs = torch.cat((vi, vj))

            Ei = (wJiT.view(B, N, D, ht*wd, -1) * Jz[:, :, None]).sum(dim=-1)
            Ej = (wJjT.view(B, N, D, ht*wd, -1) * Jz[:, :, None]).sum(dim=-1)

            w = w.view(B, N, ht*wd, -1)
            r = r.view(B, N, ht*wd, -1)
            wk = torch.sum(w*r*Jz, dim=-1)
            Ck = torch.sum(w*Jz*Jz, dim=-1)

            # disable scale if se3
            if se3:
                Hsp[:, :, :, -1, :] = 0
                Hsp[:, :, :, :, -1] = 0
                vsp[:, :, :, -1] = 0
                Hs[:, :, -1, :] = 0
                Hs[:, :, :, -1] = 0
                vs[:, :, -1] = 0
                Ei[:, :, -1, :] = 0
                Ej[:, :, -1, :] = 0

            dx, dz = droid_backends.pgba(poses.data[0], self.disps, eta,
                                Hs, vs, Ei[0], Ej[0], Ck[0], wk[0],
                                Hsp, vsp, ii, jj, iip, jjp, t0, t1, lm, ep)

            poses = pose_retr(poses, dx[None], torch.arange(t0, t1))

        self.poses_sim3[:t1] = poses.data
        self.disps.clamp_(min=0.001, max=10)